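### #########################################
### Calculates out-of-sample error over time
### (the model is scored on successive holdout windows)
### Uses the error-minimizing model (not 1SE)
### Output is a table saved as working/overtime_errs_min.rds
### #########################################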

here::i_am("newR/08_overtime.R")

library(here)
library(brms)
library(tidyverse)
library(abind)
library(parallel)
source(here::here("newR", "00_helpers.R"))

### Fit the model on the first `window` rows (ordered by Date) and score
### its posterior predictions on the rows that immediately follow.
### Called repeatedly to evaluate performance over time.
###
### Arguments:
###   window         number of earliest rows used as training data
###   df             data frame holding a Date column, the outcome matrix
###                  `y` and the *Cand_BE candidacy indicators
###   model_formula  formula passed to brm()
###   priors         priors passed to brm()
###   horizon        number of rows after the training window that form
###                  the holdout set (default 10)
###
### Returns a data frame with columns window, mae, calib, pcp and brier;
### the metrics come from the get_*3() helpers in newR/00_helpers.R.
### When no rows follow the training window, the metrics are NA and no
### model is fitted.

overtime_func <- function(window,
                          df,
                          model_formula,
                          priors,
                          horizon = 10) {
    ## Packages and helpers are loaded inside the function because it is
    ## run on parallel workers, which do not share the caller's session.
    ## library() (unlike require()) fails loudly if a package is missing.
    library(tidyverse)
    library(brms)
    source(here::here("newR", "00_helpers.R"))

    df <- df %>%
        arrange(Date)

    train <- df %>%
        filter(row_number() <= window)

    ## Both conditions go in ONE filter() call: row_number() is recomputed
    ## after every filter(), so chaining two calls would keep up to
    ## `window + horizon` rows instead of the next `horizon` rows.
    holdout <- df %>%
        filter(row_number() > window,
               row_number() <= (window + horizon))

    ## Nothing to score against: skip the (expensive) model fit.
    ## NOTE(review): assumes the metric helpers return one value each, so
    ## a single NA row has the same shape as a normal result -- confirm.
    if (nrow(holdout) == 0) {
        return(data.frame(window = window,
                          mae = NA_real_,
                          calib = NA_real_,
                          pcp = NA_real_,
                          brier = NA_real_))
    }

    ## Fit on the training window only, so the holdout stays unseen.
    mod <- brm(model_formula,
               family = dirichlet(),
               prior = priors,
               data = train,
               control = list(adapt_delta = 0.95,
                              max_treedepth = 11),
               cores = 5, chains = 5,
               seed = 43145)

    pp <- posterior_predict(mod,
                            newdata = holdout,
                            summary = FALSE)

    candvars <- paste0(c("Con", "Lab", "Lib", "Nat", "Oth"),
                       "Cand_BE")
    candmat <- as.matrix(holdout[,candvars])

    ## (row, column) positions where the candidacy indicator is 0
    absent <- which(candmat == 0, arr.ind = TRUE)

### Replace adjusted outcome values
    ## Observed shares are set back to exactly 0 where the indicator is 0
    ## (presumably undoing a small positive floor applied upstream).
    y <- holdout$y
    y[absent] <- 0

### Zero out fitted values
    ## For every posterior draw: zero the same cells, then renormalise
    ## each row so the remaining shares sum to 1.
    draw_dims <- dim(pp)[2:3]
    for (i in seq_len(dim(pp)[1])) {
        ## drop = FALSE + explicit dim keeps a matrix even when the
        ## holdout has a single row.
        x <- pp[i, , , drop = FALSE]
        dim(x) <- draw_dims
        x[absent] <- 0
        x <- x / rowSums(x)
        pp[i, , ] <- x
    }

    windowed_mae <- get_mae3(yhat = pp,
                             y = y)
    windowed_calib <- get_calibration3(yhat = pp,
                                       y = y)

    windowed_pcp <- get_pcp3(yhat = pp,
                             y = y)

    windowed_brier <- get_brier3(yhat = pp,
                                 y = y)

    ## Return this data
    retval <- data.frame(window = window,
                         mae = windowed_mae,
                         calib = windowed_calib,
                         pcp = windowed_pcp,
                         brier = windowed_brier)

    retval
}



### Read the modelling data, then load the saved variable-selection
### formulas (presumably the source of `f_min`, which is used further
### down but not defined in this script).
dat <- readRDS(here::here("working", "tidy_model_data.rds"))
load(here::here("working", "varsel_formulas.rdata"))

### Outcome share columns
pvars <- c("Con_BE", "Lab_BE", "Lib_BE", "Nat_BE", "Oth2_BE")

### Swap every non-positive share for a small positive value
dat[pvars] <- lapply(dat[pvars],
                     function(share) {
                         replace(share, share <= 0, 1 / 40000)
                     })

### Rescale so that each row of shares sums to 1 again
dat[pvars] <- dat[pvars] / rowSums(dat[pvars])

### Bind the shares into a matrix column to act as the dependent variable
dat$y <- with(dat,
              cbind(Con_BE, Lab_BE, Lib_BE, Nat_BE, Oth2_BE))

### Intercept and precision (phi) priors
common_priors <- c(
    prior("normal(0, 0.429)", class = "Intercept", dpar = "muLabBE"),
    prior("normal(-0.95, 0.584)", class = "Intercept", dpar = "muLibBE"),
    prior("normal(-2.457, 0.714)", class = "Intercept", dpar = "muNatBE"),
    prior("normal(-0.95, 0.584)", class = "Intercept", dpar = "muOth2BE"),
    prior("normal(11.5, 3.25)", class = "phi"))

### Labour Poll Change is common to both models
common_priors <- c(
    common_priors,
    prior("normal(0.241, 0.12)", coef = "LabPollChg_sc", dpar = "muLabBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muLibBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muNatBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muOth2BE"))

### Now set up priors
### Start with the CV-minimizing model first
priors_min <- common_priors

### Query the model's prior table once and keep the named slope terms;
### both the coefficient names and the distributional parameters are
### read from this one table.
slope_terms <- get_prior(f_min,
                         family = dirichlet(),
                         data = dat) %>%
    filter(class == "b") %>%
    filter(coef != "")

### LabPollChg_sc already has its priors in common_priors
coefs <- setdiff(pull(slope_terms, coef),
                 "LabPollChg_sc")

dpars <- unique(pull(slope_terms, dpar))

for (k in coefs) {
    for (p in dpars) {
        ## e.g. "muLabBE" -> "Lab", "muOth2BE" -> "Oth"
        short_party <- sub("mu(...).?BE", "\\1", p)
        ## If it's a candidacy term and it concerns the party we're
        ## modelling, use the wider prior; otherwise the tighter one.
        ## Both tests are scalar, hence && rather than &.
        own_candidacy <- grepl("Cand_BE", k) && grepl(short_party, k)
        if (own_candidacy) {
            prior_string <- "normal(0, 1)"
        } else {
            prior_string <- "normal(0, 0.25)"
        }

        priors_min <- c(priors_min,
                        set_prior(prior_string, class = "b",
                                  coef = k,
                                  dpar = p))
    }
}


### Results are cached on disk: if the file already exists it is read
### back, otherwise every window is evaluated in parallel and saved.
overtime_file <- here::here("working",
                            "overtime_errs_min.rds")

if (file.exists(overtime_file)) {
    holder <- readRDS(overtime_file)
} else {

### Remember, each call will spin up five cores!
### (overtime_func fits with cores = 5, so ncores workers can occupy
### up to ncores * 5 cores at once.)
    ncores <- 2
    cl <- makeCluster(ncores)
    ## Training-window sizes: 10, 20, ... up to nrow(dat)
    inseq <- seq(10, nrow(dat),
                 by = 10)
    ## finally = ... shuts the workers down even if a fit fails, so a
    ## failed run does not leave worker processes behind; the error itself
    ## still propagates.
    holder <- tryCatch(
        parLapply(cl, inseq, overtime_func,
                  df = dat,
                  model_formula = f_min,
                  priors = priors_min),
        finally = stopCluster(cl))
    holder <- bind_rows(holder)
    saveRDS(holder, file = overtime_file)
}
        
